# -*- coding:utf-8 -*- 
# Utils module: useful functions to build exploits (more advanced features
# than Utils.py)

from ropgenerator.semantic.Engine import search, LMAX
from ropgenerator.Constraints import Constraint, RegsNotModified, Assertion, Chainable, StackPointerIncrement
from ropgenerator.semantic.ROPChains import ROPChain
from ropgenerator.Database import QueryType
from ropgenerator.exploit.Scanner import getFunctionAddress, findBytes, getSectionAddress
from ropgenerator.IO import verbose, string_bold, string_special, string_ropg
from ropgenerator.exploit.Utils import store_constant_address
import itertools
import ropgenerator.Architecture as Arch 


#### Put strings in memory

# Default value of the 'lmax' parameter of the STRtoMEM* functions: the
# maximum length (as reported by len() on the chain being built) allowed for
# a chain that writes a string into memory. popMultiple() also adds it to its
# own length budget.
STR_TO_MEM_LMAX = 6000

def STRtoMEM(string, address, constraint, assertion, limit=None, lmax=STR_TO_MEM_LMAX,
        addr_str=None, hex_info=False, optimizeLen=False):
    """
    Put a string into memory.

    Three strategies are tried in this order: memcpy(), strcpy(), then
    direct memory-write gadgets. Without optimizeLen the first strategy that
    produces a chain wins and the remaining ones are not run. With
    optimizeLen every strategy is run and the smallest chain is kept.

    string      : the string to write
    address     : the address where the string should be written
    constraint  : forwarded unchanged to each strategy
    assertion   : forwarded unchanged to each strategy
    limit       : if (int) then the max address where to write in memory
                  if None then string should be written ONLY at 'address'
                  (no adjust possible)
    lmax        : maximum chain length, forwarded to each strategy
    addr_str    : forwarded to each strategy (label used in chain comments)
    hex_info    : forwarded to each strategy
    optimizeLen : if True, run every strategy and keep the smallest chain

    return value:
            a pair (address, ROPChain) or (None, None)
    """
    if( limit is not None and address > limit ):
        return (None, None)

    strategies = [
        ("Trying with memcpy()", STRtoMEM_memcpy),
        ("Trying with strcpy()", STRtoMEM_strcpy),
        ("Trying with gadgets only", STRtoMEM_write),
    ]

    best = (None, None)
    for (message, strategy) in strategies:
        # Once a chain has been found, further strategies are only worth
        # running when the caller asked for the shortest possible chain.
        if( best[1] and not optimizeLen ):
            break
        verbose(message)
        candidate = strategy(string, address, constraint, assertion, limit, lmax, addr_str, hex_info)
        # Relies on the ordering defined on the chain objects to pick the
        # smaller one (presumably by length -- confirm in ROPChain).
        if( (not best[1]) or (candidate[1] and candidate[1] < best[1]) ):
            best = (candidate[0], candidate[1])

    return best

def STRtoMEM_strcpy(string, addr, constraint, assertion, limit=None, lmax=STR_TO_MEM_LMAX , addr_str=None, hex_info=False):
    """
    STRCPY STRATEGY
    Copy the string using strcpy function

    The string is split by findBytes() into null-terminated substrings found
    in the binary; each one is copied to its destination with one strcpy()
    call built by build_call().

    string     : the string to write
    addr       : preferred destination address
    constraint : constraint object; its bad bytes restrict usable addresses
    assertion  : forwarded to build_call()
    limit      : highest address usable for the write, or None to write
                 exactly at 'addr'
    lmax       : maximum length of the resulting chain
    addr_str   : label used for the destination in the chain comments
                 (defaults to hex(addr))
    hex_info   : if True, substrings are shown as \\xNN escapes in comments

    return value:
            a pair (address, ROPChain) or (None, None)
    """
    if( not addr_str ):
        addr_str = hex(addr)

    # Getting strcpy function
    (func_name, func_addr ) = getFunctionAddress('strcpy')
    if( not func_addr ):
        verbose('Could not find strcpy function')
        return (None,None)
    elif( not constraint.badBytes.verifyAddress(func_addr)):
        verbose("strcpy address ({}) contains bad bytes".format(hex(func_addr)))
        return (None,None)

    # We decompose the string in substrings to be copied
    substrings_addr = findBytes(string, badBytes = constraint.getBadBytes(), add_null=True)
    if( not substrings_addr ):
        return (None,None)

    def advance(substring):
        # Distance between this copy's destination and the next one.
        # strcpy appends a null byte that the next copy overwrites, hence -1,
        # except when the substring IS the null byte we intend to write.
        if( substring == '\x00' ):
            return len(substring)
        return len(substring) - 1

    # Find delivery address. The lengths must follow the same rule as the
    # destination adjustment in the loop below, otherwise the addresses
    # checked for bad bytes are not the ones actually used.
    substr_lengthes = [advance(substr[1]) for substr in substrings_addr]
    # The last copy keeps its terminating null byte
    substr_lengthes[-1] = len(substrings_addr[-1][1])
    if( limit is None ):
        # No adjustment allowed: only 'addr' itself fits below this bound
        upper_bound = addr + sum(substr_lengthes)
    else:
        upper_bound = limit
    custom_stack = find_closest_base_fake_stack_address(addr, upper_bound, substr_lengthes, constraint)
    if( custom_stack is None ):
        verbose("Couldn't write string in memory because of bad bytes")
        return (None,None)
    if( custom_stack != addr ):
        addr_str = hex(custom_stack)

    # Build chain
    res = ROPChain()
    offset = 0
    saved_custom_stack = custom_stack
    for (substring_addr,substring_str) in substrings_addr:
        if( hex_info ):
            substring_info = '\\x'+'\\x'.join(["%02x"%ord(c) for c in substring_str])
        else:
            substring_info = substring_str
        # strcpy(dest, src): the destination is the FIRST argument, as in the
        # memcpy strategy, so it must come first in the argument list.
        commentStack="Arg1: " + string_ropg("{} + {}".format(addr_str, offset))
        commentSubStr="Arg2: " + string_ropg(substring_info)
        func_call = build_call(func_name, [custom_stack, substring_addr], constraint, assertion, [commentStack, commentSubStr], optimizeLen=True)
        if( isinstance(func_call, str)):
            # build_call reports failures as an error string
            verbose("strcpy: " + func_call)
            return (None,None)
        res.addChain(func_call)
        if( len(res) > lmax ):
            return (None,None)
        # Adjust destination for the next substring
        step = advance(substring_str)
        custom_stack = custom_stack + step
        offset = offset + step

    return (saved_custom_stack, res)

def STRtoMEM_memcpy(string, addr, constraint, assertion, limit=None, lmax=STR_TO_MEM_LMAX , addr_str=None, hex_info=False):
    """
    MEMCPY STRATEGY
    Copy the string using memcpy function

    The string is split by findBytes() into pieces found in the binary and
    each piece is copied to its destination with one memcpy() call.

    return value:
            a pair (address, ROPChain) or (None, None)
    """
    label = addr_str if addr_str else hex(addr)

    # Locate memcpy and make sure its address can be used
    (func_name, func_addr) = getFunctionAddress('memcpy')
    if( not func_addr ):
        verbose('Could not find memcpy function')
        return (None,None)
    if( not constraint.badBytes.verifyAddress(func_addr) ):
        verbose("memcpy address ({}) contains bad bytes".format(hex(func_addr)))
        return (None,None)

    # Pieces of the string available in the binary: (address, bytes) pairs
    pieces = findBytes(string, badBytes = constraint.getBadBytes())
    if( not pieces ):
        return (None,None)

    # Pick a destination whose intermediate addresses are all usable
    piece_sizes = [len(piece[1]) for piece in pieces]
    if( limit is None ):
        upper_bound = addr + sum(piece_sizes)
    else:
        upper_bound = limit
    dest_base = find_closest_base_fake_stack_address(addr, upper_bound, piece_sizes, constraint)
    if( dest_base is None ):
        verbose("Couldn't write string in memory because of bad bytes")
        return (None,None)
    if( dest_base != addr ):
        label = hex(dest_base)

    # One memcpy(dest, src, size) call per piece
    chain = ROPChain()
    written = 0
    for (piece_addr, piece) in pieces:
        size = len(piece)
        if( hex_info ):
            shown = "'"+'\\x'+'\\x'.join(["%02x"%ord(c) for c in piece])+"'"
        else:
            shown = "'"+piece+"'"
        size_comment = "Arg3: " + string_ropg(str(size))
        src_comment = "Arg2: " + string_ropg(shown)
        dest_comment = "Arg1: " + string_ropg("{} + {}".format(label, written))

        call = build_call(func_name, [dest_base + written, piece_addr, size],
                    constraint, assertion, [dest_comment, src_comment, size_comment], optimizeLen=True)
        if( isinstance(call, str) ):
            # build_call reports failures as an error string
            verbose("memcpy: " + call)
            return (None,None)

        chain.addChain(call)
        if( len(chain) > lmax ):
            return (None,None)

        written += size

    return (dest_base, chain)
    
def STRtoMEM_write(string, addr, constraint, assertion, limit=None, lmax=STR_TO_MEM_LMAX , addr_str=None, hex_info=False):
    """
    WRITE STRATEGY
    Copy the string using mem(XXX) <- YYY gadgets

    The string is cut by find_best_valid_writes() into (address, value)
    pairs, and each pair becomes one constant-to-memory store chain.

    string     : the string to write
    addr       : preferred destination address
    constraint : constraint forwarded to the helpers
    assertion  : assertion forwarded to store_constant_address()
    limit      : highest address usable for the write (forwarded to
                 find_best_valid_writes())
    lmax       : maximum length of the resulting chain
    addr_str, hex_info : accepted for signature parity with the other
                 strategies; not used here

    return value:
            a pair (address, ROPChain) or (None, None)
    """
    # We decompose the string in (address, value) writes
    substrings_addr = find_best_valid_writes(addr, string, constraint, limit)
    # None means no valid layout; an empty list means there is nothing to
    # write, in which case there is no first address to return either.
    if( not substrings_addr ):
        return (None, None)

    # Build chain
    res = ROPChain()
    for (substring_addr,substring_val) in substrings_addr:
        if( substring_val is None ):
            # The substring could not be packed into an integer value
            verbose("Couldn't convert substring into a value to write")
            return (None, None)
        # Each store only gets the budget left by the previous ones
        write_chain = store_constant_address(QueryType.CSTtoMEM, substring_addr, substring_val, constraint, assertion, clmax=lmax-len(res), optimizeLen=True)
        if( not write_chain ):
            verbose("Coudln't find suitable memory write ropchain")
            return (None, None)
        res.addChain(write_chain)
    return (substrings_addr[0][0], res)
    
# Util function
def find_closest_base_fake_stack_address(base, limit, substr_lengthes, constraint):
    """
    Find a base address from which every substring can be written.

    Substrings are written back to back:
        BASE                         <- SUB1
        BASE + LEN(SUB1)             <- SUB2
        BASE + LEN(SUB1) + LEN(SUB2) <- SUB3 ...
    Any of these destination addresses may contain a bad byte. Starting at
    'base', candidates are tried in increasing order until all destination
    addresses are accepted by constraint.badBytes.findIndex().

    base            : first candidate base address
    limit           : the whole write (base + sum of lengths) must not go
                      past this address
    substr_lengthes : lengths of the successive substrings
    constraint      : object whose badBytes.findIndex(address) returns the
                      index of a bad byte in the address, or -1 if none

    Returns the chosen base address, or None if no candidate fits.
    """
    # Offsets of each destination relative to the base. The last length is
    # left out because nothing is written after the last substring.
    offsets = [0]
    running = 0
    for length in substr_lengthes[:-1]:
        running += length
        offsets.append(running)

    span = sum(substr_lengthes)
    candidate = base
    while( candidate + span <= limit ):
        bad_index = None
        for off in offsets:
            bad_index = constraint.badBytes.findIndex(candidate + off)
            if( bad_index < 0 ):
                continue
            # A bad byte sits at position bad_index of this destination.
            byte_mask = 0xff << bad_index*8
            if( (candidate & byte_mask) == byte_mask ):
                # That byte of the candidate is already 0xff: give up
                return None
            # Bump that byte of the candidate by one and start over
            candidate += (0x1 << bad_index*8)
            break
        if( bad_index == -1 ):
            # Every destination address was accepted
            return candidate

    # Ran past the limit without finding a usable base
    return None

def find_best_valid_writes(addr, string, constraint, limit=None):
    """
    When using the write strategy, can have bad bytes in addresses too...
    Try to adjust the destination so that every write address is valid.

    The string is cut into chunks of at most Arch.octets() bytes. Starting
    at 'addr', candidate base addresses are tried one byte at a time until
    one is found for which the base and every following write address pass
    constraint.badBytes.verifyAddress().

    addr       : first candidate base address
    string     : the string to write
    constraint : object providing badBytes.verifyAddress() and
                 getValidPadding()
    limit      : highest address the end of the string may reach; when
                 None, addr + len(string) + 10 is used

    Returns a list of (address, value) pairs, one per write, or None if no
    valid layout was found. A chunk value is None when it could not be
    packed (no valid padding, or unknown endianness).
    """
    def string_into_reg(string):
        # Pack a chunk into an integer as it must appear in a register so
        # that a memory store lays the chunk's bytes out in order.
        bytes_list = [b for b in string]
        pad_len = Arch.octets() - len(bytes_list)
        # Chunks shorter than a machine word are completed with padding
        # (assumes getValidPadding(n) returns an n-byte integer -- TODO confirm)
        if( pad_len != 0 ):
            padding = constraint.getValidPadding(pad_len)
            if( padding is None ):
                return None
        else:
            padding = 0

        if( Arch.isLittleEndian()):
            # First byte of the chunk is the least significant one; the
            # padding ends up in the most significant bytes.
            value = padding
            for byte in reversed(bytes_list):
                value = (value << 8) + ord(byte)
            return value
        elif( Arch.isBigEndian()):
            # First byte of the chunk is the most significant one; the
            # padding fills the pad_len least significant bytes.
            value = 0
            for byte in bytes_list:
                value = (value << 8) + ord(byte)
            return (value << (8*pad_len)) + padding
        else:
            return None

    def writes_from(base):
        # Cut the string into writes starting at 'base'; None if some write
        # address cannot be made valid.
        writes = []
        offset = 0
        while( offset < len(string) ):
            # Largest chunk size for which the address following the chunk
            # (i.e. the next write address) contains no bad byte
            size = None
            for i in reversed(range(1,Arch.octets()+1)):
                if( constraint.badBytes.verifyAddress(base+offset+i)):
                    size = i
                    break
            if( size is None ):
                return None
            writes.append((base+offset, string_into_reg(string[offset:offset+size])))
            offset += size
        return writes

    if( limit is None ):
        limit = addr+len(string)+10

    tmp_addr = addr
    while(tmp_addr+len(string) <= limit):
        # The first write address is the base itself, so it must be valid too
        if( constraint.badBytes.verifyAddress(tmp_addr) ):
            writes = writes_from(tmp_addr)
            if( writes is not None ):
                return writes
        # This base does not work: try the next one (without this the loop
        # would retry the same address forever)
        tmp_addr += 1
    return None


#### Pop values into registers 

# Default chain-length budget granted per argument by popMultiple() when the
# caller gives no clmax (the total budget is this value times len(args)).
POP_MULTIPLE_LMAX_PER_REG = 80

def popMultiple(args, constraint=None, assertion=None, clmax=None, addr=None, limit=None, optimizeLen=False):
    """
    Creates a chain that pops values into regs.

    args is a list of pairs (reg, value)
        OR a list of triples (reg, value, comment)
    reg is a reg UID
    value is an int OR a string. A string is first written to memory with
        STRtoMEM() and the register then receives its address.
    constraint, assertion : defaults are Constraint() and Assertion()
    clmax : maximum chain length; defaults to POP_MULTIPLE_LMAX_PER_REG per
        argument. STR_TO_MEM_LMAX is added for each string argument.
    addr and limit are used to put strings if args contains strings; when
        addr is not given the address of the .bss section is used.
    optimizeLen : forwarded to the searches

    Every ordering of the arguments is tried until one yields a full chain,
    since setting one register must not clobber the ones already set.

    Returns a ROPChain, or None if no chain could be built.
    """
    args = list(args)
    if( clmax is None ):
        clmax = POP_MULTIPLE_LMAX_PER_REG*len(args)
    elif( clmax <= 0 ):
        return None
    # Strings need extra room for the chain that writes them to memory.
    # The value to inspect is arg[1]; arg itself is always a tuple.
    has_string = False
    for arg in args:
        if( isinstance(arg[1], str)):
            clmax += STR_TO_MEM_LMAX
            has_string = True

    if( constraint is None ):
        constr = Constraint()
    else:
        constr = constraint
    if( assertion is None ):
        a = Assertion()
    else:
        a = assertion

    # A scratch address is only needed when there is a string to store, so
    # integer-only calls must not fail because .bss is missing.
    if( has_string and not addr ):
        # Get the .bss address
        addr = getSectionAddress('.bss')
        if( not addr ):
            return None

    for perm in itertools.permutations(args):
        clmax_tmp = clmax
        res = ROPChain()
        constr_tmp = constr
        tmp_addr = addr
        chains = None
        for arg in perm:
            if( len(arg) == 3 ):
                comment = arg[2]
            else:
                comment = None
            if( isinstance(arg[1], int)):
                chains = search(QueryType.CSTtoREG, arg[0], arg[1], constr_tmp, a, n=1, clmax=clmax_tmp, CSTtoREG_comment=comment, optimizeLen=optimizeLen)
            elif( isinstance(arg[1], str)):
                (address, str_to_mem) = STRtoMEM(arg[1], tmp_addr, constr_tmp, a, limit=limit, lmax=clmax_tmp, addr_str=comment, hex_info=True, optimizeLen=optimizeLen)
                if( not str_to_mem ):
                    chains = None
                    break
                # The next string goes right after this one: advance by the
                # length of the string, not by the size of the arg tuple.
                tmp_addr = address + len(arg[1])
                pop = search(QueryType.CSTtoREG, arg[0], address, constr_tmp, a, n=1, clmax=clmax_tmp-len(str_to_mem), optimizeLen=optimizeLen)
                if( not pop ):
                    chains = None
                    break
                # addChain() extends the chain in place (it is used that way
                # throughout this file)
                str_to_mem.addChain(pop[0])
                chains = [str_to_mem]
            else:
                raise Exception("UNknown argument type in popMultiple: '{}'".format(type(arg[1])))

            if( not chains ):
                break
            clmax_tmp -= len(chains[0])
            # If Reached max length, exit
            if( clmax_tmp < 0 ):
                chains = None
                break
            res.addChain(chains[0])
            # Registers already set must survive the following gadgets
            constr_tmp = constr_tmp.add(RegsNotModified([arg[0]]))
        if( chains ):
            return res
    return None


## Call functions
def build_call(funcName, funcArgs, constraint, assertion, argsDescription=None, clmax=None, optimizeLen=False):
    """
    Build a chain that calls a function, dispatching on the binary type.

    funcName        : name of the function to call
    funcArgs        : list of argument values
    constraint      : forwarded to the platform-specific builder
    assertion       : forwarded to the platform-specific builder
    argsDescription : optional list of comments, one per argument, in the
                      same order as funcArgs
    clmax           : maximum chain length, forwarded to the builder
    optimizeLen     : forwarded to the builder

    Returns what the platform-specific builder returns (a chain on success),
    or an error message (str) on failure.
    """
    # Merge description and args into tuples (value,) or (value, description).
    # Arguments beyond the end of the descriptions are kept, just without
    # a description.
    descriptions = list(argsDescription) if argsDescription else []
    merged = []
    for (position, arg) in enumerate(funcArgs):
        if( position < len(descriptions) ):
            merged.append((arg, descriptions[position]))
        else:
            merged.append((arg,))

    if( Arch.currentBinType == Arch.BinaryType.X86_ELF ):
        return build_call_linux86(funcName, merged, constraint, assertion, clmax, optimizeLen)
    elif( Arch.currentBinType == Arch.BinaryType.X64_ELF ):
        return build_call_linux64(funcName, merged, constraint, assertion, clmax, optimizeLen)
    # Failures are reported as strings (that is what callers in this file
    # test for), so an unsupported target must not look like a chain.
    return "Function calls are not supported for this binary type"
        
def build_call_linux64(funcName, funcArgs, constraint, assertion, clmax=None, optimizeLen=False):
    """
    Build a function call for Linux x64: arguments are loaded into
    rdi, rsi, rdx, rcx, r8, r9 (in that order), then the function address
    is appended to the chain.

    funcArgs : list of tuples (value,) or (value, comment)

    Returns the chain, or an error message (str) on failure.
    """
    # Registers receiving the arguments, in argument order
    arg_regs = [Arch.n2r(reg_name) for reg_name in ['rdi','rsi','rdx','rcx', 'r8', 'r9']]

    # Resolve the function
    (resolved_name, resolved_addr) = getFunctionAddress(funcName)
    if( resolved_name is None ):
        return "Couldn't find function '{}' in the binary".format(funcName)

    # The function address itself goes in the payload: no bad bytes allowed
    if( not constraint.badBytes.verifyAddress(resolved_addr) ):
        hex_width = Arch.octets()*2
        shown_addr = string_special('0x'+format(resolved_addr, '0{}x'.format(hex_width)))
        return "'{}' address ({}) contains bad bytes".format(resolved_name, shown_addr)

    # Only register arguments are handled
    if( len(funcArgs) > 6 ):
        return "Doesn't support function call with more than 6 arguments with Linux X64 calling convention :("

    if( not funcArgs ):
        # Nothing to load
        loaded = ROPChain()
    else:
        # Prefix each argument tuple with its register and pop them all
        reg_args = [(reg,) + arg for (reg, arg) in zip(arg_regs, funcArgs)]
        loaded = popMultiple(reg_args, constraint, assertion, clmax=clmax, optimizeLen=optimizeLen)
        if( not loaded ):
            return "Couldn't load arguments in registers"

    # Finish with the function address
    return loaded.addPadding(resolved_addr, comment=string_ropg(resolved_name))
    

def build_call_linux86(funcName, funcArgs, constraint, assertion, clmax=None, optimizeLen=False):
    """
    Build a function call for Linux x86: the function address, a fake
    return address pointing to a gadget that skips the arguments, then the
    arguments themselves.

    funcName    : name of the function to call
    funcArgs    : list of argument values, either plain ints or tuples
                  (value,) / (value, comment) as produced by build_call()
    constraint  : constraint used for bad bytes and gadget search
    assertion   : forwarded to the gadget search
    clmax       : maximum chain length, or None for no limit
    optimizeLen : forwarded to the gadget search

    Returns the chain, or an error message (str) on failure.
    """
    funcArgs = list(funcArgs)

    # Find the address of the function
    (funcName2, funcAddr) = getFunctionAddress(funcName)
    if( funcName2 is None ):
        return "Couldn't find function '{}' in the binary".format(funcName)

    # Check if bad bytes in function address
    if( not constraint.badBytes.verifyAddress(funcAddr) ):
        return "'{}' address ({}) contains bad bytes".format(funcName2, string_special('0x'+format(funcAddr, '0'+str(Arch.octets()*2)+'x')))

    # Check if lmax too small: function address + arguments + (if there are
    # arguments) the fake return address. No limit when clmax is None.
    needed = 1 + len(funcArgs) + (1 if funcArgs else 0)
    if( clmax is not None and needed > clmax ):
        return "Not enough bytes to call function '{}'".format(funcName)

    # Normalize arguments to (value, comment-or-None) pairs; build_call()
    # hands over tuples, direct callers may pass plain ints.
    plain_args = []
    for arg in funcArgs:
        if( isinstance(arg, (tuple, list)) ):
            value = arg[0]
            description = arg[1] if len(arg) > 1 else None
        else:
            value = arg
            description = None
        if( not isinstance(value, int) ):
            return "Type of argument '{}' not supported yet :'(".format(value)
        plain_args.append((value, description))

    # Find a gadget for the fake return address
    if( plain_args ):
        offset = (len(plain_args)-1)*Arch.octets() # Because we do +octets() at the beginning of the loop
        skip_args_chains = []
        i = 4 # Try 4 more maximum
        while( i > 0 and (not skip_args_chains)):
            offset += Arch.octets()
            skip_args_chains = search(QueryType.MEMtoREG, Arch.ipNum(), \
                        (Arch.spNum(),offset), constraint, assertion, n=1, optimizeLen=optimizeLen)
            i -= 1

        if( not skip_args_chains ):
            return "Couldn't build ROP-Chain"
        skip_args_chain = skip_args_chains[0]
    else:
        # No arguments
        skip_args_chain = None

    # Build the ropchain with the arguments
    # NOTE(review): arguments are added last-to-first, as in the original
    # code; confirm this matches the order in which addPadding() lays them
    # out on the stack.
    args_chain = ROPChain()
    arg_n = len(plain_args)
    for (value, description) in reversed(plain_args):
        if( description ):
            arg_comment = description
        else:
            arg_comment = "Arg{}: {}".format(arg_n, string_ropg(hex(value)))
        args_chain.addPadding(value, comment=arg_comment)
        arg_n -= 1

    # Build call chain (function address + fake return address)
    call_chain = ROPChain()
    call_chain.addPadding(funcAddr, comment=string_ropg(funcName2))
    if( plain_args ):
        # NOTE(review): validAddrStr is not imported in the visible part of
        # this file -- confirm it is in scope, otherwise this raises NameError.
        skip_args_addr = int( validAddrStr(skip_args_chain.chain[0], constraint.getBadBytes(), Arch.bits())  ,16)
        call_chain.addPadding(skip_args_addr, comment="Address of: "+string_bold(str(skip_args_chain.chain[0])))

    return call_chain.addChain(args_chain)
